package cn.swing.main.srv.cv.utils;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * An implemetation of the Kuhn–Munkres assignment algorithm of the year 1957.
 * https://en.wikipedia.org/wiki/Hungarian_algorithm
 *
 * @author https://github.com/aalmi | march 2014
 * @version 1.0
 */
public class HungarianAlgorithm {

    double[][] matrix; // initial matrix (cost matrix)

    // markers in the matrix
    int[] squareInRow, squareInCol, rowIsCovered, colIsCovered, staredZeroesInRow;

    public HungarianAlgorithm(double[][] costMatrix) {
        // 1. 代价矩阵预处理
        double[][] adjustedMatrix = new double[costMatrix.length][costMatrix[0].length];
        for (int i = 0; i < costMatrix.length; i++) {
            for (int j = 0; j < costMatrix[0].length; j++) {
                adjustedMatrix[i][j] = costMatrix[i][j];
            }
        }

        // 2. 矩阵填充（保证方阵）
        int dim = Math.max(costMatrix.length, costMatrix[0].length);
        double[][] squareMatrix = new double[dim][dim];
        for (int i = 0; i < dim; i++) {
            Arrays.fill(squareMatrix[i], 0);
            if (i < costMatrix.length) {
                System.arraycopy(adjustedMatrix[i], 0, squareMatrix[i], 0, costMatrix[0].length);
            }
        }

        this.matrix = squareMatrix;

        squareInRow = new int[dim];       // squareInRow & squareInCol indicate the position
        squareInCol = new int[dim];    // of the marked zeroes

        rowIsCovered = new int[dim];      // indicates whether a row is covered
        colIsCovered = new int[dim];   // indicates whether a column is covered
        staredZeroesInRow = new int[dim]; // storage for the 0*
        Arrays.fill(staredZeroesInRow, -1);
        Arrays.fill(squareInRow, -1);
        Arrays.fill(squareInCol, -1);
    }

    /**
     * find an optimal assignment
     *
     * @return optimal assignment
     */
    public int[][] matchTracks() {
        step1();    // reduce matrix
        step2();    // mark independent zeroes
        step3();    // cover columns which contain a marked zero

        while (!allColumnsAreCovered()) {
            int[] mainZero = step4();
            while (mainZero == null) {      // while no zero found in step4
                step7();
                mainZero = step4();
            }
            if (squareInRow[mainZero[0]] == -1) {
                // there is no square mark in the mainZero line
                step6(mainZero);
                step3();    // cover columns which contain a marked zero
            } else {
                // there is square mark in the mainZero line
                // step 5
                rowIsCovered[mainZero[0]] = 1;  // cover row of mainZero
                colIsCovered[squareInRow[mainZero[0]]] = 0;  // uncover column of mainZero
                step7();
            }
        }

        int[][] optimalAssignment = new int[matrix.length][];
        while (true) {
            boolean breakFlag = true;
            // 计算每行的0的数量和每列的0的数量，将零元素最少得行列选出来，标记上。
            int[] rowZeroNumber = new int[matrix.length];
            int[] colZeroNumber = new int[matrix[0].length];
            for (int i = 0; i < matrix.length; i++) {
                for (int j = 0; j < matrix[0].length; j++) {
                    if (matrix[i][j] == 0) {
                        breakFlag = false;
                        rowZeroNumber[i]++;
                        colZeroNumber[j]++;
                    }
                }
            }
            if (breakFlag) {
                break;
            }
            int minRow = -1;
            int minCol = -1;
            int minRowNumber = Integer.MAX_VALUE;
            int minColNumber = Integer.MAX_VALUE;
            for (int i = 0; i < rowZeroNumber.length; i++) {
                if (rowZeroNumber[i] < minRowNumber && rowZeroNumber[i] != 0) {
                    minRowNumber = rowZeroNumber[i];
                    minRow = i;
                }
            }
            for (int j = 0; j < colZeroNumber.length; j++) {
                if (colZeroNumber[j] < minColNumber && colZeroNumber[j] != 0) {
                    minColNumber = colZeroNumber[j];
                    minCol = j;
                }
            }
            if (minRowNumber < minColNumber) {
                for (int j = 0; j < matrix[0].length; j++) {
                    if (matrix[minRow][j] == 0) {
                        optimalAssignment[minRow] = new int[] {minRow, j};
                        // matrix矩阵中minRow行所有的值和j列所有的值 置为-1
                        for (int row = 0; row < matrix.length; row++) {
                            if (matrix[row][j] == 0) {
                                matrix[row][j] = -1;
                            }
                        }
                        for (int col = 0; col < matrix.length; col++) {
                            if (matrix[minRow][col] == 0) {
                                matrix[minRow][col] = -1;
                            }
                        }
                        break;
                    }
                }
            } else {
                for (int i = 0; i < matrix.length; i++) {
                    if (matrix[i][minCol] == 0) {
                        optimalAssignment[i] = new int[] {i, minCol};
                        // matrix矩阵中i行所有的值和minCol列所有的值 置为-1
                        for (int col = 0; col < matrix[0].length; col++) {
                            if (matrix[i][col] == 0) {
                                matrix[i][col] = -1;
                            }
                        }
                        for (int row = 0; row < matrix.length; row++) {
                            if (matrix[row][minCol] == 0) {
                                matrix[row][minCol] = -1;
                            }
                        }
                        break;
                    }
                }
            }
        }
        return optimalAssignment;
    }

    /**
     * Check if all columns are covered. If that's the case then the
     * optimal solution is found
     *
     * @return true or false
     */
    private boolean allColumnsAreCovered() {
        for (int i : colIsCovered) {
            if (i == 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Step 1:
     * Reduce the matrix so that in each row and column at least one zero exists:
     * 1. subtract each row minima from each element of the row
     * 2. subtract each column minima from each element of the column
     */
    private void step1() {
        // rows
        for (int i = 0; i < matrix.length; i++) {
            // find the min value of the current row
            double currentRowMin = Double.MAX_VALUE;
            for (int j = 0; j < matrix[i].length; j++) {
                if (matrix[i][j] < currentRowMin) {
                    currentRowMin = matrix[i][j];
                }
            }
            // subtract min value from each element of the current row
            if (currentRowMin != 0) {
                for (int k = 0; k < matrix[i].length; k++) {
                    matrix[i][k] -= currentRowMin;
                }
            }
        }

        // cols
        for (int i = 0; i < matrix[0].length; i++) {
            // find the min value of the current column
            double currentColMin = Double.MAX_VALUE;
            for (int j = 0; j < matrix.length; j++) {
                if (matrix[j][i] < currentColMin) {
                    currentColMin = matrix[j][i];
                }
            }
            // subtract min value from each element of the current column
            if (currentColMin != 0) {
                for (int k = 0; k < matrix.length; k++) {
                    matrix[k][i] -= currentColMin;
                }
            }
        }
    }

    /**
     * Step 2:
     * mark each 0 with a "square", if there are no other marked zeroes in the same row or column
     */
    private void step2() {
        int[] rowHasSquare = new int[matrix.length];
        int[] colHasSquare = new int[matrix[0].length];

        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[0].length; j++) {
                // mark if current value == 0 & there are no other marked zeroes in the same row or column
                if (matrix[i][j] == 0 && rowHasSquare[i] == 0 && colHasSquare[j] == 0) {
                    rowHasSquare[i] = 1;
                    colHasSquare[j] = 1;
                    squareInRow[i] = j; // save the row-position of the zero
                    squareInCol[j] = i; // save the column-position of the zero
                }
            }
        }
    }

    /**
     * Step 3:
     * Cover all columns which are marked with a "square"
     */
    private void step3() {
        for (int i = 0; i < squareInCol.length; i++) {
            colIsCovered[i] = squareInCol[i] != -1 ? 1 : 0;
        }
    }

    /**
     * Step 7:
     * 1. Find the smallest uncovered value in the matrix.
     * 2. Subtract it from all uncovered values
     * 3. Add it to all twice-covered values
     */
    private void step7() {
        // Find the smallest uncovered value in the matrix
        double minUncoveredValue = Double.MAX_VALUE;
        for (int i = 0; i < matrix.length; i++) {
            if (rowIsCovered[i] == 1) {
                continue;
            }
            for (int j = 0; j < matrix[0].length; j++) {
                if (colIsCovered[j] == 0 && matrix[i][j] < minUncoveredValue) {
                    minUncoveredValue = matrix[i][j];
                }
            }
        }

        if (minUncoveredValue > 0) {
            for (int i = 0; i < matrix.length; i++) {
                for (int j = 0; j < matrix[0].length; j++) {
                    if (rowIsCovered[i] == 1 && colIsCovered[j] == 1) {
                        // Add min to all twice-covered values
                        matrix[i][j] += minUncoveredValue;
                    } else if (rowIsCovered[i] == 0 && colIsCovered[j] == 0) {
                        // Subtract min from all uncovered values
                        matrix[i][j] -= minUncoveredValue;
                    }
                }
            }
        }
    }

    /**
     * Step 4:
     * Find zero value Z_0 and mark it as "0*".
     *
     * @return position of Z_0 in the matrix
     */
    private int[] step4() {
        for (int i = 0; i < matrix.length; i++) {
            if (rowIsCovered[i] == 0) {
                for (int j = 0; j < matrix[i].length; j++) {
                    if (matrix[i][j] == 0 && colIsCovered[j] == 0) {
                        staredZeroesInRow[i] = j; // mark as 0*
                        return new int[] {i, j};
                    }
                }
            }
        }
        return null;
    }

    /**
     * Step 6:
     * Create a chain K of alternating "squares" and "0*"
     *
     * @param mainZero => Z_0 of Step 4
     */
    private void step6(int[] mainZero) {
        int i = mainZero[0];
        int j = mainZero[1];

        Set<int[]> K = new LinkedHashSet<>();
        //(a)
        // add Z_0 to K
        K.add(mainZero);
        boolean found = false;
        do {
            // (b)
            // add Z_1 to K if
            // there is a zero Z_1 which is marked with a "square " in the column of Z_0
            if (squareInCol[j] != -1) {
                K.add(new int[] {squareInCol[j], j});
                found = true;
            } else {
                found = false;
            }

            // if no zero element Z_1 marked with "square" exists in the column of Z_0, then cancel the loop
            if (!found) {
                break;
            }

            // (c)
            // replace Z_0 with the 0* in the row of Z_1
            i = squareInCol[j];
            j = staredZeroesInRow[i];
            // add the new Z_0 to K
            if (j != -1) {
                K.add(new int[] {i, j});
                found = true;
            } else {
                found = false;
            }

        } while (found); // (d) as long as no new "square" marks are found

        // (e)
        for (int[] zero : K) {
            // remove all "square" marks in K
            if (squareInCol[zero[1]] == zero[0]) {
                squareInCol[zero[1]] = -1;
                squareInRow[zero[0]] = -1;
            }
            // replace the 0* marks in K with "square" marks
            if (staredZeroesInRow[zero[0]] == zero[1]) {
                squareInRow[zero[0]] = zero[1];
                squareInCol[zero[1]] = zero[0];
            }
        }

        // (f)
        // remove all marks
        Arrays.fill(staredZeroesInRow, -1);
        Arrays.fill(rowIsCovered, 0);
        Arrays.fill(colIsCovered, 0);
    }

}